Skip to content

Conversation

neo-alex
Copy link
Contributor

This PR fixes a Tensorflow trainer bug that arises when the first (flatten) input can be None in model.fit/model.evaluate (which is possible for optional inputs since PR #21548). Note: this PR makes the original code more robust but still assumes that at least one input is not None (to properly extract the batch size).

This fix was originally part of the bigger PR #21609 but is now pushed in a small dedicated PR as agreed with @hertschuh here.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Summary of Changes

Hello @neo-alex, I'm Gemini Code Assist1! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request enhances the robustness of the TensorFlow trainer by addressing an edge case where optional None inputs could cause issues during the calculation of batch-related metrics. It ensures that the training and testing steps can correctly determine the batch size even when some inputs are not provided, leading to more stable model operations.

Highlights

  • Tensorflow Trainer Robustness: Fixed a bug in the Tensorflow trainer where model.fit and model.evaluate would fail to correctly calculate sample_weight if the first flattened input was None. The fix ensures that the batch size is now extracted from the first available non-None input.
Using Gemini Code Assist

The full guide for Gemini Code Assist can be found on our documentation page, here are some quick tips.

Invoking Gemini

You can request assistance from Gemini at any point in your pull request via creating an issue comment (i.e. comment on the pull request page) using either /gemini <command> or @gemini-code-assist <command>. Below is a summary of the supported commands.

Feature Command Description
Code Review /gemini review Performs a code review for the current pull request in its current state.
Pull Request Summary /gemini summary Provides a summary of the current pull request in its current state.
Comment @gemini-code-assist Responds in comments when explicitly tagged, both in issue comments and review comments.
Help /gemini help Displays a list of available commands.

Customization

To customize Gemini Code Assist for GitHub experience, repository maintainers can create a configuration file and/or provide a custom code review style guide (such as PEP-8 for Python) by creating and adding files to a .gemini/ folder in the base of the repository. Detailed instructions can be found here.

Limitations & Feedback

Gemini Code Assist may make mistakes. Please leave feedback on any instances where its feedback is incorrect or counter productive. You can react with 👍 and 👎 on @gemini-code-assist comments. If you're interested in giving your feedback about your experience with Gemini Code Assist for Github and other Google products, sign up here.

You can also get AI-powered code generation, chat, as well as code reviews directly in the IDE at no cost with the Gemini Code Assist IDE Extension.

Footnotes

  1. Review the Privacy Notices, Generative AI Prohibited Use Policy, Terms of Service, and learn how to configure Gemini Code Assist in GitHub here. Gemini can make mistakes, so double check it and use code with caution.

Copy link
Contributor

@gemini-code-assist gemini-code-assist bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Code Review

This PR fixes a bug where the Tensorflow trainer would crash if the first input was None. The change correctly finds the first non-None input to determine the batch size. My review focuses on improving the robustness of this fix by handling the edge case where all inputs might be None, which would currently cause an unhandled exception. I've also pointed out that the new logic is duplicated and could be refactored for better maintainability.

Comment on lines +71 to +73
sample_weight=tf.shape(
next(i for i in tree.flatten(x) if i is not None)
)[0],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This change correctly handles cases where the first input is None. However, it introduces a risk of a StopIteration error if all inputs in x are None. This can be difficult to debug, especially inside a tf.function.

A more robust approach would be to handle this edge case explicitly, for example by raising a ValueError with a clear message.

Also, this logic is duplicated in test_step. Consider extracting it into a private helper method to improve maintainability and ensure consistency.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

Comment on lines +101 to +103
sample_weight=tf.shape(
next(i for i in tree.flatten(x) if i is not None)
)[0],
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

Similar to the train_step, this change is vulnerable to a StopIteration error if all inputs are None. Explicitly handling this edge case would make the code more robust and prevent potential runtime crashes.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

+1

@codecov-commenter
Copy link

codecov-commenter commented Aug 31, 2025

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 82.49%. Comparing base (4415fcc) to head (d25afcb).

Additional details and impacted files
@@           Coverage Diff           @@
##           master   #21630   +/-   ##
=======================================
  Coverage   82.49%   82.49%           
=======================================
  Files         572      572           
  Lines       57451    57451           
  Branches     8982     8982           
=======================================
  Hits        47395    47395           
  Misses       7760     7760           
  Partials     2296     2296           
Flag Coverage Δ
keras 82.30% <ø> (ø)
keras-jax 63.52% <ø> (ø)
keras-numpy 57.84% <ø> (ø)
keras-openvino 34.33% <ø> (ø)
keras-tensorflow 64.18% <ø> (ø)
keras-torch 63.75% <ø> (ø)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

🚀 New features to boost your workflow:
  • ❄️ Test Analytics: Detect flaky tests, report on failures, and find test suite problems.

@JyotinderSingh
Copy link
Collaborator

Thanks for the fix! Could you please add a test for the same? It helps verify that the issue is actually fixed and we don't see regressions in the future.

Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR! Please add a simple unit test to test the fix.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

6 participants